import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
from utils import undersample
from utils import ss_combine
from skimage import metrics
import numpy as np

# Read the 8-channel brain dataset and bring the last (coil) axis to the front.
mat = sio.loadmat('brain_8ch.mat')
im = mat['im'].transpose(2, 0, 1)  # coil images, [Coils, X, Y]
map = mat['map'].transpose(2, 0, 1)  # coil sensitivity maps, [Coils, X, Y]


def gfactor(coil_map, Rx, Ry):
    """Compute the SENSE geometry-factor (g-factor) map for regular undersampling.

    For every pixel, the coil sensitivities of that pixel and of the pixels
    that fold onto it (offsets of ``Nx // Rx`` rows and ``Ny // Ry`` columns,
    wrapping around the field of view) are stacked into a matrix ``S`` with one
    column per folded pixel.  The g-factor of the pixel is then

        g = sqrt( [(S^H S)^+]_00 * [S^H S]_00 )

    where index 0 refers to the pixel itself.  The noise covariance between
    coils is taken to be the identity.

    Args:
        coil_map: Array of coil sensitivities indexed as [coil, x, y].
        Rx: Integer undersampling factor along the x axis (axis 1).
        Ry: Integer undersampling factor along the y axis (axis 2).

    Returns:
        Real array of shape (Nx, Ny).  Pixels where every coil sensitivity has
        magnitude below 1e-6 are left at 0.
    """
    _, Nx, Ny = coil_map.shape
    shift_x = Nx // Rx
    shift_y = Ny // Ry
    g = np.zeros((Nx, Ny))

    for i in range(Nx):
        for j in range(Ny):
            # No coil sees this pixel: nothing to reconstruct, keep g = 0.
            if np.all(np.abs(coil_map[:, i, j]) < 1e-6):
                continue

            # Column 0 is the pixel itself (k == l == 0); the remaining
            # columns are the pixels aliased onto it.
            sens = []
            for k in range(Rx):
                for l in range(Ry):
                    ndx = (i + k * shift_x) % Nx
                    ndy = (j + l * shift_y) % Ny
                    sens.append(coil_map[:, ndx, ndy])
            sens_mat = np.array(sens).T

            s = sens_mat.conj().T @ sens_mat
            # pinv copes with all-zero columns (aliased pixels that lie
            # outside the coil support), which would make a plain inverse fail.
            si = np.linalg.pinv(s)

            # Both diagonal entries are real for a Hermitian matrix; taking
            # the real part drops round-off and keeps g a real-valued map.
            # trace(s @ si) would only return the rank of S, not the g-factor.
            product = np.real(si[0, 0]) * np.real(s[0, 0])
            g[i, j] = np.sqrt(max(product, 0.0))
    return g

def sense(ima, map, Rx, Ry):
    """Unfold regularly undersampled multi-coil images with SENSE.

    For every pixel inside the coil support, the sensitivities of the pixel
    and of the pixels aliased onto it (offsets of ``Nx // Rx`` rows and
    ``Ny // Ry`` columns, wrapping around the field of view) form the columns
    of a matrix ``S``.  The aliased coil values ``m`` at that pixel are then
    unfolded by least squares, ``(S^H S)^+ S^H m``, and the first entry (the
    pixel itself) is stored.

    Args:
        ima: Aliased coil images indexed as [coil, x, y].
        map: Coil sensitivity maps indexed as [coil, x, y].  (The name shadows
            the builtin but is kept so existing callers keep working.)
        Rx: Integer undersampling factor along the x axis (axis 1).
        Ry: Integer undersampling factor along the y axis (axis 2).

    Returns:
        Array of shape (Nx, Ny) with the dtype of ``ima``.  Pixels outside the
        coil support, and pixels whose unfolding system has a condition number
        of 1e12 or more, are left at 0.
    """
    _, Nx, Ny = map.shape
    shift_x = Nx // Rx
    shift_y = Ny // Ry

    # A pixel belongs to the support when ANY coil sees it.  Testing coil 0
    # alone would discard pixels where only that one coil has a null.
    support = np.any(np.abs(map) >= 1e-6, axis=0)
    recon = np.zeros((Nx, Ny), dtype=ima.dtype)

    for i in range(Nx):
        for j in range(Ny):
            if not support[i, j]:
                continue

            # Column 0 is always the pixel itself (k == l == 0, which is in
            # the support here); aliased pixels outside the support are
            # skipped because they contribute no signal.
            columns = []
            for k in range(Rx):
                for l in range(Ry):
                    ndx = (i + k * shift_x) % Nx
                    ndy = (j + l * shift_y) % Ny
                    if support[ndx, ndy]:
                        columns.append(map[:, ndx, ndy])

            s = np.column_stack(columns)
            scs = s.conj().T @ s
            # Skip numerically hopeless systems; the pixel then stays 0.
            if np.linalg.cond(scs) < 1e12:
                unfolded = np.linalg.pinv(scs) @ s.conj().T @ ima[:, i, j]
                recon[i, j] = unfolded[0]
    return recon

# def undersample(im, Rx, Ry):
#     Nc, Nx, Ny = im.shape
#     ima = np.zeros((Nc, Nx, Ny), dtype=complex)
#     ma = np.zeros((Nc, Nx, Ny), dtype=complex)

#     msk = np.zeros((Nx, Ny))
#     msk[::Rx, ::Ry] = 1
    
#     for i in range(Nc):
#         ma[i, :, :] = np.fft.fft2(im[i, :, :]) * msk
#         ima[i, :, :] = np.fft.ifft2(ma[i, :, :])
    
#     return ima
# --- Undersample the coil images, reconstruct with SENSE, score the result ---
Rx, Ry = 2, 1
imu = undersample(im, Rx, Ry)
# sigpy image viewer on the undersampled coil images.
imu_p = pl.ImagePlot(imu, z=0, hide_axes=True)

# g-factor map computation is currently disabled:
# g_map = gfactor(map, Rx, Ry)

# SENSE reconstruction of the undersampled data, as a magnitude image scaled
# so that its maximum is 1.
# NOTE(review): the factors are passed as (Ry, Rx) although sense() is declared
# as sense(ima, map, Rx, Ry) -- confirm this matches the axis convention of
# utils.undersample.
im_sense = np.abs(sense(imu, map, Ry, Rx))
im_sense = im_sense / im_sense.max()
plt.imshow(im_sense, cmap='gray')
plt.savefig("sense.png")

# Reference image: the fully sampled data run through the same routine with
# no acceleration (R = 1 in both directions), normalised the same way.
im_ss = np.abs(sense(im, map, 1, 1))
im_ss = im_ss / im_ss.max()
plt.imshow(im_ss, cmap='gray')
plt.savefig("sense_gt.png")

# Peak signal-to-noise ratio with im_ss as the reference; the data range is
# taken from the reconstructed image.
data_range = im_sense.max() - im_sense.min()
pSNR = metrics.peak_signal_noise_ratio(im_ss, im_sense, data_range=data_range)
print(f"pSNR for SENSE {pSNR!s} dB")

# NOTE(review): im_ss is already a normalised magnitude image, yet it is passed
# through sp.ifft before being displayed -- confirm this is intended.
plt.imshow(np.abs(sp.ifft(im_ss)), cmap='gray', vmin=0, vmax=1)
plt.show()